Stroke is one of the most common causes of death, accounting for approximately 11% of total deaths according to the World Health Organization.
These worrying statistics led us to study the influence of certain factors in triggering a stroke. This project explores a dataset containing parameters such as age, hypertension, body mass index, gender and a few more, in order to determine which patient profile is most exposed to a stroke.
Given these statistics, studying the characteristics of the patients could be a first step in understanding why stroke is one of the most common causes of death. For the implementation, we used Python, the Spark framework, SparkSQL, TensorFlow, and R for some of the plots.
We started by reading and processing the dataset. For this task we used pyspark. We created a file (DataInit.py) that is used as a Python library within our project; it contains the functions used to manipulate the dataset. The following code section is from DataInit.py and reads the dataset. The get_data_frame() function is called each time another Python file needs the dataset.
from pyspark.sql import SparkSession
from pyspark.sql.functions import count
spark1 = SparkSession.builder.appName('Stroke').getOrCreate()
data_frame = spark1.read.csv('/content/sample_data/healthcare-dataset-stroke-data.csv',inferSchema=True,header=True)
data_frame.createOrReplaceTempView("DATA")
def get_data_frame():
    return data_frame
The first step in EDA is to inspect the column names, column data types, irrelevant or missing data, and so on. We created a new file (StrokePlots.py) that imports the initial DataInit.py file.
!pip install pyspark
!wget -q https://downloads.apache.org/spark/spark-3.0.1/spark-3.0.1-bin-hadoop3.2.tgz
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, isnan, when, count, lit
import DataInit
# Get data from DataInit file
data_frame = DataInit.get_data_frame()
data = data_frame.toPandas()
# Start EDA
print("*** Started Exploratory Data Analysis ***")
print("Show summary")
DataInit.data_summary_show(data_frame)
print("Show inconsistent data (i.e null, nan, etc.)")
data_frame.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in data_frame.columns]).show()
The code shows the following output:
Output image:
Next, we will plot the information for each column, to see more details about all of them.
We used pie plots for the columns with few distinct values (i.e., 2 or 3). In our opinion, this type of plot is a good choice for showing the distribution of a small number of distinct values, and it also displays the percentages.
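As a rough sketch of what the percentages on such a pie chart represent, the snippet below computes each distinct value's share of a column in plain Python. The helper name `value_shares` and the sample values are made up for illustration and are not taken from the dataset.

```python
from collections import Counter

def value_shares(values):
    """Return {value: percentage of rows}, most frequent value first."""
    counts = Counter(values)
    total = sum(counts.values())
    return {value: round(100 * n / total, 1) for value, n in counts.most_common()}

# Hypothetical column with three distinct values (illustrative only).
sample_gender = ["Female"] * 6 + ["Male"] * 3 + ["Other"]
shares = value_shares(sample_gender)
print(shares)
```

A value with a very small share, like the third one here, is exactly the kind of category that stands out on a pie plot.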
The first one shows the distribution by ‘gender’. We can see that the value ‘Other’ is irrelevant; we will remove it later, in the data cleaning step.
data = data_frame.toPandas()
data['gender'].value_counts().plot.pie(autopct='%1.1f%%', startangle=90, explode=[0.1, 0.1, 0.1], figsize=(9,6), colors = ['#25BBAA', '#e3a59c'], textprops={'fontsize': 12})
Output image:
The next one shows the distribution by ‘stroke’. We can observe that a small part of the people in the dataset have already had a stroke.
data = data_frame.toPandas()
data['stroke'].value_counts().plot.pie(autopct='%1.1f%%', startangle=90, explode=[0.1, 0.1], figsize=(9,6), colors = ['#25BBAA', '#e3a59c'], textprops={'fontsize': 12})
Output image:
The next one shows the distribution by ‘heart_disease’.
data = data_frame.toPandas()
data['heart_disease'].value_counts().plot.pie(autopct='%1.1f%%', startangle=90, explode=[0.1, 0.1], figsize=(9,6), colors = ['#25BBAA', '#e3a59c'], textprops={'fontsize': 12})
Output image:
The next one shows the distribution by ‘hypertension’.
data = data_frame.toPandas()
data['hypertension'].value_counts().plot.pie(autopct='%1.1f%%', startangle=90, explode=[0.1, 0.1], figsize=(9,6), colors = ['#25BBAA', '#e3a59c'], textprops={'fontsize': 12})
Output image:
The next one shows the distribution by ‘ever_married’.
data = data_frame.toPandas()
data['ever_married'].value_counts().plot.pie(autopct='%1.1f%%', startangle=90, explode=[0.1, 0.1], figsize=(9,6), colors = ['#25BBAA', '#e3a59c'], textprops={'fontsize': 12})
Output image:
The next one shows the distribution by ‘Residence_type’.
data = data_frame.toPandas()
data['Residence_type'].value_counts().plot.pie(autopct='%1.1f%%', startangle=90, explode=[0.1, 0.1], figsize=(9,6), colors = ['#25BBAA', '#e3a59c'], textprops={'fontsize': 12})
Output image:
The next one shows the distribution by ‘work_type’. We used a countplot for this column. We can observe that a lot of people from the dataset are working in the private sector.
data = data_frame.toPandas()
chart = sns.countplot(x='work_type', data=data, palette=('#e3a59c','#E1503C','#25BBAA', '#17F9DF'))
chart.set_xticklabels(chart.get_xticklabels(), rotation=90)
Output image:
The next one shows the distribution by ‘smoking_status’. We used a countplot for this column. Note the ‘Unknown’ status: we will remove these records, as we cannot draw any relevant conclusion from them.
data = data_frame.toPandas()
chart = sns.countplot(x='smoking_status', data=data, palette=('#e3a59c','#E1503C','#25BBAA', '#17F9DF'))
chart.set_xticklabels(chart.get_xticklabels(), rotation=90)
Output image:
The next one shows the distribution by ‘age’. We used a histplot for this column because this type of plot is good for showing the distribution of numerical values.
data = data_frame.toPandas()
sns.histplot(x='age', data=data, color="#25BBAA");
Output image:
The next one shows the distribution by ‘avg_glucose_level’. We used a histplot for this column because this type of plot is good for showing the distribution of numerical values.
data = data_frame.toPandas()
sns.histplot(x='avg_glucose_level', data=data, color="#25BBAA");
Output image:
The next one shows the distribution by ‘bmi’. We used a histplot for this column because this type of plot is good for showing the distribution of numerical values. We observe that the data is not shown correctly. After taking a closer look at the dataset, we see that a few records have the string ‘N/A’ in this column. The standard functions we used to find null/NaN values did not detect them, because ‘N/A’ is stored as ordinary text rather than as a missing value.
data = data_frame.toPandas()
sns.histplot(x='bmi', data=data, color="#25BBAA");
Output image:
For this reason, we will remove these records in order to show the correct plot of the ‘bmi’ column. We will also remove them later, in the data cleaning process.
Output image:
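A minimal sketch of why a null check can miss these records, assuming the CSV stores missing BMI values as the literal text 'N/A': the value is a non-empty string, so it is neither null nor NaN. The tiny inline CSV and the helper `parse_bmi` are hypothetical and only illustrate the idea; this is not the code from DataInit.py.

```python
import csv
import io

# Tiny made-up CSV; 'N/A' is a plain string, not an empty or null field.
raw = "id,bmi\n1,36.6\n2,N/A\n3,28.1\n"
rows = list(csv.DictReader(io.StringIO(raw)))

def parse_bmi(text):
    """Convert a bmi field to float, or return None when it is not numeric."""
    try:
        return float(text)
    except ValueError:
        return None

# A naive null check finds nothing, because 'N/A' is a non-empty string.
naive_missing = [r for r in rows if r["bmi"] is None or r["bmi"] == ""]

# Parsing exposes the bad records, which can then be filtered out.
clean = [dict(r, bmi=parse_bmi(r["bmi"])) for r in rows
         if parse_bmi(r["bmi"]) is not None]

print(len(naive_missing), len(clean))
```

Once the non-numeric rows are filtered out and the column is numeric, a histogram can bin the values properly.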
Running the distribution plot for the ‘bmi’ column again, we observe a totally different output:
data = data_frame.toPandas()
sns.histplot(x='bmi', data=data, color="#25BBAA");
Output image:
The next step is cleaning up the dataset. As seen from the EDA plots, we need to:
- remove the records with ‘gender’ = ‘Other’;
- remove the records with ‘smoking_status’ = ‘Unknown’;
- remove the records with ‘bmi’ = ‘N/A’;
- drop the column ‘id’.
We used the following function, implemented in DataInit.py.
# Return the cleaned up dataset
def data_cleaning(data_frame):
    print("***Cleaning the dataset***:")
    print("***Drop missing values***")
    # na.drop() returns a new DataFrame, so the result must be reassigned
    data_frame = data_frame.na.drop()
    data_frame.select('gender').distinct().rdd.map(lambda r: r[0]).collect()
    spark1.sql("SELECT * FROM DATA WHERE GENDER = 'Other'").show()
    print("***Remove gender == 'Other'***:")
    data_frame = data_frame.filter(data_frame.gender != 'Other')
    print("***Remove smoking_status == 'Unknown'***")
    data_frame.select('smoking_status').distinct().rdd.map(lambda r: r[0]).collect()
    data_frame = data_frame.filter(data_frame.smoking_status != 'Unknown')
    print("***Remove bmi == 'N/A'***")
    data_frame = data_frame.filter(data_frame.bmi != 'N/A')
    print("***Drop column 'id'***")
    data_frame = data_frame.drop(data_frame['id'])
    print("***Finished cleaning the dataset***:")
    return data_frame
Running it from the StrokePlots.py file, we get the following output:
Output image:
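One way to double-check the result of the cleaning step is a small rule checker. The sketch below works on plain dictionaries rather than on a Spark DataFrame; the function name `find_cleaning_problems` and the sample records are hypothetical, and the rules simply restate the four cleaning steps listed above.

```python
def find_cleaning_problems(records):
    """Return human-readable messages for records that break a cleaning rule."""
    problems = []
    for i, rec in enumerate(records):
        if rec.get("gender") == "Other":
            problems.append(f"row {i}: gender is 'Other'")
        if rec.get("smoking_status") == "Unknown":
            problems.append(f"row {i}: smoking_status is 'Unknown'")
        if rec.get("bmi") == "N/A":
            problems.append(f"row {i}: bmi is 'N/A'")
        if "id" in rec:
            problems.append(f"row {i}: column 'id' still present")
    return problems

# Made-up records: the first passes, the second breaks two rules.
sample = [
    {"gender": "Male", "smoking_status": "smokes", "bmi": "27.3"},
    {"id": 7, "gender": "Female", "smoking_status": "Unknown", "bmi": "31.0"},
]
problems = find_cleaning_problems(sample)
print(problems)
```

An empty list means that none of the four rules is violated by the records that were checked.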
!pip install pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col,isnan, when, count
from pyspark.sql.types import StringType, StructField, StructType
from pyspark.ml.feature import StringIndexer, OneHotEncoder
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
!wget -q https://downloads.apache.org/spark/spark-3.0.1/spark-3.0.1-bin-hadoop3.2.tgz
spark1 = SparkSession.builder.appName('Stroke').getOrCreate()
df = spark1.read.csv('/content/sample_data/final_dataset.csv',inferSchema=True,header=True)
df.printSchema()
df.show()
df.columns
df.count()
df.summary().show()
df.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in df.columns]
).show()
#string indexing
SI_gender= StringIndexer(inputCol='gender',outputCol='gender_index')
SI_ever_married= StringIndexer(inputCol='ever_married',outputCol='ever_married_index')
SI_work_type= StringIndexer(inputCol='work_type',outputCol='work_type_index')
SI_residence_type= StringIndexer(inputCol='Residence_type',outputCol='Residence_type_index')
SI_smoking_status= StringIndexer(inputCol='smoking_status',outputCol='smoking_status_index')
df= SI_gender.fit(df).transform(df)
df= SI_ever_married.fit(df).transform(df)
df= SI_work_type.fit(df).transform(df)
df= SI_residence_type.fit(df).transform(df)
df= SI_smoking_status.fit(df).transform(df)
df.select('gender', 'gender_index', 'ever_married', 'ever_married_index','work_type','work_type_index','Residence_type','Residence_type_index','smoking_status','smoking_status_index').show(10)
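To make the indexing step more concrete, here is a plain-Python sketch of the default ordering that StringIndexer is documented to use, where the most frequent label receives index 0. It does not call Spark; the alphabetical tie-breaking, the helper name `frequency_index`, and the sample labels are assumptions made for illustration.

```python
from collections import Counter

def frequency_index(labels):
    """Map each label to an index: most frequent first, ties alphabetical."""
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}

# Made-up work_type values (illustrative only).
sample_work_type = ["Private", "Govt_job", "Private",
                    "Self-employed", "Private", "Govt_job"]
mapping = frequency_index(sample_work_type)
indexed = [mapping[label] for label in sample_work_type]
print(mapping)
print(indexed)
```

Note that the resulting numbers encode frequency rank only; they carry no natural ordering between categories.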
#classification
df= df.drop(df['gender'])
df=df.drop(df['ever_married'])
df=df.drop(df['work_type'])
df=df.drop(df['smoking_status'])
df=df.drop(df['Residence_type'])
df=df.drop(df['bmi'])
data =df.toPandas()
X = data.drop('stroke', axis=1)
y = data['stroke']
data.size
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=101)
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
scaler.fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
print(X_train.shape)
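The scaling step can be sketched in plain Python for a single column: the minimum and maximum are learned from the training values only and then reused for the test values, using (x - min) / (max - min). The helper names and the ages below are made up for illustration.

```python
def fit_min_max(train_column):
    """Learn the minimum and maximum from the training values only."""
    return min(train_column), max(train_column)

def apply_min_max(column, lo, hi):
    """Scale values with (x - lo) / (hi - lo) using the learned bounds."""
    span = hi - lo
    return [(x - lo) / span if span else 0.0 for x in column]

# Made-up ages (illustrative only).
train_age = [20.0, 40.0, 60.0, 80.0]
test_age = [30.0, 90.0]

lo, hi = fit_min_max(train_age)
scaled_train = apply_min_max(train_age, lo, hi)
scaled_test = apply_min_max(test_age, lo, hi)
print(scaled_train)
print(scaled_test)
```

Because the bounds come from the training set, a test value outside the training range (90 here) is scaled to a number above 1.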
#tensorflow &neural network with 2 layers
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Dropout
model = Sequential()
model.add(Dense(units=9, activation='relu', input_shape=(9,))) # First layer
model.add(Dense(units=9, activation='relu')) # Second layer
model.add(Dense(units=1, activation='sigmoid')) # Output layer
model.summary()
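As a worked check of the model size, the number of trainable parameters of a dense layer is inputs × units weights plus one bias per unit. The sketch below applies that to the 9-9-9-1 layout defined above; the helper name `dense_params` is hypothetical.

```python
def dense_params(n_inputs, n_units):
    """Weights (n_inputs * n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

# Input features, first hidden layer, second hidden layer, output unit.
layer_sizes = [9, 9, 9, 1]
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer, total)
```

These per-layer counts can be compared against the parameter column printed by the summary.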
model.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
from tensorflow.keras.callbacks import EarlyStopping
early_stop = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=25)
history = model.fit(x=X_train, y=y_train,epochs=100,validation_data=(X_test, y_test), verbose=1, callbacks=[early_stop] )
print(type(history.history))
print(history.history['val_loss'][:5])
print(history.history['loss'][:5])
plt.plot(history.history['loss'], c='cornflowerblue', label='training loss')
plt.plot(history.history['val_loss'], c='goldenrod', label='validation loss')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('loss');
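The patience mechanism behind early stopping can be sketched as follows. This is a simplified plain-Python illustration, not the Keras implementation: it ignores options such as a minimum improvement threshold, and the function name and loss values are made up.

```python
def early_stop_epoch(val_losses, patience):
    """Return the 0-based epoch at which training would stop, or None.

    Training stops once the validation loss has failed to improve on its
    best value for `patience` consecutive epochs.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, wait = loss, 0   # improvement: reset the counter
        else:
            wait += 1              # no improvement this epoch
            if wait >= patience:
                return epoch
    return None

# Made-up validation losses (illustrative only).
val_losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.34, 0.38, 0.39, 0.40]
stopped_at = early_stop_epoch(val_losses, patience=3)
print(stopped_at)
```

With a large patience such as 25, a short run like this one would never trigger the stop and the function returns None.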
model = Sequential()
model.add(Dense(units=9, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(units=9, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(units=1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam')
history = model.fit(x=X_train, y=y_train, epochs=200, validation_data=(X_test, y_test), verbose=1, callbacks=[early_stop] )
plt.plot(history.history['loss'], c='cornflowerblue', label='training loss')
plt.plot(history.history['val_loss'], c='goldenrod', label='validation loss')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('loss');
Neural network with dropout layers
Neural network without dropout layers
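For readers unfamiliar with dropout, the following plain-Python sketch shows the idea behind a dropout layer with rate 0.5 during training: each activation is zeroed with probability equal to the rate, and the surviving ones are scaled up so that the expected sum is unchanged. The function name, the seed, and the activations are assumptions for illustration; it is not the Keras code.

```python
import random

def dropout(activations, rate, rng):
    """Inverted dropout: zero each value with probability `rate` and
    scale the survivors by 1 / (1 - rate)."""
    keep = 1.0 - rate
    return [a / keep if rng.random() >= rate else 0.0 for a in activations]

rng = random.Random(0)   # fixed seed so the sketch is repeatable
hidden = [1.0] * 9       # made-up activations of a 9-unit layer
dropped = dropout(hidden, rate=0.5, rng=rng)
print(dropped)
```

At prediction time no units are dropped, which is why the training loss of a network with dropout can look noisier than its validation loss.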
As expected, this plot shows that senior patients (60+) are the most exposed to a stroke; the ages at which most patients suffered a stroke were 79 and 80. After that the incidence appears to decrease, but this is because only a few patients over the age of 82 are included in the dataset.
library(tidyverse)
library(ggplot2)
df<-read_csv("final_dataset.csv")
df1<-df[df[, "stroke"]==1, ]
df1 %>%
group_by(age) %>%
summarize(total_strokes=sum(stroke)) %>%
ggplot(aes(x=age, y=total_strokes)) +
geom_bar(stat="identity", fill="steelblue")+
geom_text(aes(label=total_strokes), vjust=1.6, color="black", size=3)+
ggtitle("Total stroke cases by age") +
xlab("Age") + ylab("Total Strokes")+
theme(plot.title = element_text(hjust = 0.5), panel.background = element_blank())
BMI (body mass index) is a parameter obtained from a patient’s weight and height. A normal value lies between 18.5 and 24.9. A person with a value greater than 25 is considered overweight, and greater than 30, obese. Here you can see the distribution of BMI for the patients we analysed:
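The thresholds above can be written as a small helper. This is a minimal sketch: BMI is computed as weight in kilograms divided by the square of height in metres, the 'underweight' label for values below 18.5 is an assumption not stated in the text, and values between 24.9 and 25 are treated as normal here. The function names and the example person are hypothetical.

```python
def bmi(weight_kg, height_m):
    """Body mass index: weight in kilograms divided by height in metres squared."""
    return weight_kg / height_m ** 2

def bmi_category(value):
    """Classify a BMI value using the thresholds quoted in the text.

    The 'underweight' label for values below 18.5 is an assumption.
    """
    if value < 18.5:
        return "underweight"
    if value < 25:
        return "normal"
    if value < 30:
        return "overweight"
    return "obese"

example = bmi(70, 1.75)   # made-up person: 70 kg, 1.75 m
print(round(example, 1), bmi_category(example))
```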
ggplot(df, aes(x = bmi, y = age)) +
  geom_point(color = "chocolate4") +
  # bmi is numeric, so a continuous scale is needed for the breaks to apply
  labs(x = "BMI", y = "Age") + scale_x_continuous(breaks = round(seq(0, 70, by = 5),0))+ggtitle("Body mass index by age")+
  # theme_light() goes first: applied last, a complete theme would override the theme() tweaks
  theme_light()+
  theme(axis.title = element_text(color = "sienna", size = 15, face = "bold"),
        axis.title.y = element_text(face = "bold.italic"),plot.title = element_text(lineheight = 8, size = 16))
The next plots illustrate the average glucose levels of the registered patients. A normal value for this parameter should be between 90 and 110; values greater than 110 can indicate diabetes. For better visualization, we used viridis color palettes; below you can find 5 representations:
library(ggplot2)
gpl<- ggplot(df, aes(x = avg_glucose_level, y = age,color=age)) +
geom_point() +
labs(x = "Average glucose level", y = "Age", color = "Age:")+
scale_x_continuous(breaks = round(seq(10, 300, by = 50),0)) + scale_y_continuous(breaks = round(seq(10, 300, by = 50),0))
gpl + scale_color_continuous()
plot1 <- gpl + scale_color_viridis_c() + ggtitle("'viridis' (default)")
plot2 <- gpl + scale_colour_viridis_c(option = "inferno") + ggtitle("'inferno'")
plot3 <- gpl + scale_colour_viridis_c(option = "plasma") + ggtitle("'plasma'")
plot4 <- gpl + scale_colour_viridis_c(option = "cividis") + ggtitle("'cividis'")
library(patchwork)
(plot1 + plot2 + plot3 + plot4) * theme(legend.position = "bottom")
Below you can see a 3D scatter plot illustrating the link between BMI, age and the average glucose level.
library(tidyverse)
library(plotly)
plot <- plot_ly(
df[df[, "stroke"]==1, ], x = ~bmi, y = ~avg_glucose_level, z = ~age,
color = ~Residence_type, colors = c('#BF382A', '#0C4B8E')
) %>%
add_markers() %>%
layout(
scene = list(xaxis = list(title = 'Body mass index'),
yaxis = list(title = 'Avg_glucose_level'),
zaxis = list(title = 'Age'))
)
plot
Next, we selected a sample of 10 patients who suffered a stroke. As hypertension is frequently related to factors like age and smoking status, this plot shows the link between these parameters:
df1<- df[df[, "stroke"]==1, ]
library(plotly)
df1$hypertension<-as.factor(df1$hypertension)
ggplot(head(df1,10), aes(x =rownames(head(df1,10)), y = smoking_status)) +
geom_point(aes(color = hypertension, size = age), alpha = 0.5) +
scale_color_manual(values = c("#00AFBB", "#FC4E07")) + xlab("Sample")+
scale_size(range = c(0.5, 12))
The following plot shows the working status of the patients. As we can see, the “never worked” status belongs to young people, and the number of people with their own business seems to increase with age, starting from about 25 years.
library(ggridges)
library(ggplot2)
ggplot(df, aes(x = age, y = work_type, fill = work_type)) +
geom_density_ridges() +
theme_ridges() +
theme(legend.position = "none")+
ggtitle("Ridgeline plot working type by age ") +
xlab("Age") + ylab("Working status")+
theme(plot.title = element_text(hjust = 0.5), panel.background = element_blank())
Last but not least, we wrote an R function to determine the percentage of people in our dataset who suffered a stroke and how this percentage is distributed between the sexes. Here is the result:
library(ggplot2)
library(dplyr)
# Format a proportion as a percentage string with two decimals
percent <- function(x, digits = 2, format = "f", ...) {
  paste0(formatC(100 * x, format = format, digits = digits, ...), "%")
}
df %>%
  group_by(smoking_status, gender) %>%
  # one row per group; the rate stays numeric so the y axis is continuous
  summarize(freq = sum(stroke) / nrow(df), .groups = "drop") %>%
  ggplot(aes(x = smoking_status, y = freq, fill = gender)) +
  geom_bar(stat = "identity", position = "dodge") +
  # labels are formatted as percentages and dodged to sit on their own bars
  geom_text(aes(label = percent(freq)), position = position_dodge(width = 0.9), vjust = 1.6, color = "black", size = 3) +
  ggtitle("Stroke rate per gender determined by smoking status") +
  xlab("Smoking_status") + ylab("Stroke rate") +
  theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1)) +
  theme(plot.title = element_text(hjust = 0.5), panel.background = element_blank())
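For comparison, the same per-group computation can be sketched in plain Python: for each (smoking_status, gender) pair, the number of strokes in the group is divided by the total number of records and formatted as a percentage. The function name and the records are made up for illustration and do not come from the dataset.

```python
from collections import defaultdict

def stroke_share_by_group(records):
    """For each (smoking_status, gender) pair, return the strokes in that
    group divided by the total number of records, as a percentage string."""
    strokes = defaultdict(int)
    for rec in records:
        strokes[(rec["smoking_status"], rec["gender"])] += rec["stroke"]
    total = len(records)
    return {group: f"{100 * n / total:.2f}%" for group, n in strokes.items()}

# Made-up records (illustrative only).
sample = [
    {"smoking_status": "smokes", "gender": "Male", "stroke": 1},
    {"smoking_status": "smokes", "gender": "Male", "stroke": 0},
    {"smoking_status": "smokes", "gender": "Female", "stroke": 1},
    {"smoking_status": "never smoked", "gender": "Female", "stroke": 0},
]
shares = stroke_share_by_group(sample)
print(shares)
```

Since the denominator is the whole dataset, each value is the group's share of all records rather than the stroke rate within that group; dividing by the group size instead would give the within-group rate.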
https://www.who.int/news-room/fact-sheets/detail/the-top-10-causes-of-death
https://www.kaggle.com/datasets/fedesoriano/stroke-prediction-dataset
https://www.data-to-viz.com/
https://r-graph-gallery.com/320-the-basis-of-bubble-plot.html
https://cran.r-project.org/web/packages/viridis/vignettes/intro-to-viridis.html